{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "machine_shape": "hm",
      "gpuType": "A100"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "## Exploring simple optimizations for Stable Diffusion XL"
      ],
      "metadata": {
        "id": "VpOioQ7TfxOz"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "!nvidia-smi"
      ],
      "metadata": {
        "id": "0c6FCLrjKR5k"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WxpxpPruKD6o"
      },
      "outputs": [],
      "source": [
        "!pip install git+https://github.com/huggingface/diffusers -q\n",
        "!pip install transformers accelerate -q"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Unoptimized setup\n",
        "\n",
        "* FP32 computation\n",
        "* Default attention processor"
      ],
      "metadata": {
        "id": "zK5N3pJHfZX9"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from diffusers import StableDiffusionXLPipeline\n",
        "\n",
        "pipe = StableDiffusionXLPipeline.from_pretrained(\"stabilityai/stable-diffusion-xl-base-1.0\")\n",
        "pipe = pipe.to(\"cuda\")\n",
        "pipe.unet.set_default_attn_processor()"
      ],
      "metadata": {
        "id": "H6ipTyxQKML3"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import time\n",
        "import torch\n",
        "\n",
        "num_iterations = 3\n",
        "num_inference_steps = 25\n",
        "prompt = \"a photo of an astronaut riding a horse on mars\"\n",
        "num_images_per_prompt = 4\n",
        "\n",
        "\n",
        "def bytes_to_giga_bytes(bytes):\n",
        "    return bytes / 1024 / 1024 / 1024\n",
        "\n",
        "\n",
        "def timeit(\n",
        "    pipeline,\n",
        "    prompt_embeds=None,\n",
        "    negative_prompt_embeds=None,\n",
        "    pooled_prompt_embeds=None,\n",
        "    negative_pooled_prompt_embeds=None,\n",
        "):\n",
        "    \"\"\"Benchmark `pipeline` and print its latency and peak CUDA memory.\n",
        "\n",
        "    The pipeline is called `num_iterations` times; only the duration of the\n",
        "    last call is printed (earlier calls run but are not reported). Reads the\n",
        "    notebook-level `prompt`, `num_images_per_prompt`, `num_inference_steps`\n",
        "    and `num_iterations`.\n",
        "\n",
        "    Args:\n",
        "        pipeline: Callable invoked with the keyword arguments built below.\n",
        "        prompt_embeds: Optional precomputed embeddings. When given, they are\n",
        "            passed together with the three companion tensors instead of `prompt`.\n",
        "        negative_prompt_embeds: Forwarded as-is when `prompt_embeds` is given.\n",
        "        pooled_prompt_embeds: Forwarded as-is when `prompt_embeds` is given.\n",
        "        negative_pooled_prompt_embeds: Forwarded as-is when `prompt_embeds` is given.\n",
        "    \"\"\"\n",
        "    # Arguments shared by both calling conventions.\n",
        "    call_args = dict(\n",
        "        num_images_per_prompt=num_images_per_prompt,\n",
        "        num_inference_steps=num_inference_steps,\n",
        "    )\n",
        "    if prompt_embeds is None:\n",
        "        call_args[\"prompt\"] = prompt\n",
        "    else:\n",
        "        call_args.update(\n",
        "            prompt_embeds=prompt_embeds,\n",
        "            negative_prompt_embeds=negative_prompt_embeds,\n",
        "            pooled_prompt_embeds=pooled_prompt_embeds,\n",
        "            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,\n",
        "        )\n",
        "\n",
        "    # CUDA kernels are launched asynchronously, so synchronize on both sides of\n",
        "    # the call to make the timer bracket exactly this call's GPU work.\n",
        "    use_cuda = torch.cuda.is_available()\n",
        "    for i in range(num_iterations):\n",
        "        if use_cuda:\n",
        "            torch.cuda.synchronize()\n",
        "        # perf_counter_ns is monotonic, unlike the wall clock behind time_ns.\n",
        "        start = time.perf_counter_ns()\n",
        "        _ = pipeline(**call_args)\n",
        "        if use_cuda:\n",
        "            torch.cuda.synchronize()\n",
        "        end = time.perf_counter_ns()\n",
        "        if i == num_iterations - 1:\n",
        "            print(f\"Execution time -- {(end - start) / 1e6:.1f} ms\\n\")\n",
        "    print(\n",
        "        f\"Max memory allocated: {bytes_to_giga_bytes(torch.cuda.max_memory_allocated())} GB\"\n",
        "    )"
      ],
      "metadata": {
        "id": "FlPxa9diNVsx"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "timeit(pipe)"
      ],
      "metadata": {
        "id": "pN-9CQuXQBGV"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import gc\n",
        "import torch\n",
        "\n",
        "def flush():\n",
        "    \"\"\"Run the garbage collector, empty the CUDA cache, and reset CUDA peak-memory stats.\"\"\"\n",
        "    gc.collect()\n",
        "    torch.cuda.empty_cache()\n",
        "    torch.cuda.reset_peak_memory_stats()\n",
        "\n",
        "del pipe\n",
        "flush()"
      ],
      "metadata": {
        "id": "Yx1GizAMSLAg"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Just FP16"
      ],
      "metadata": {
        "id": "05Q6wQ9Ffh_Q"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "pipe = StableDiffusionXLPipeline.from_pretrained(\n",
        "    \"stabilityai/stable-diffusion-xl-base-1.0\", torch_dtype=torch.float16\n",
        ")\n",
        "pipe = pipe.to(\"cuda\")\n",
        "pipe.unet.set_default_attn_processor()\n",
        "\n",
        "timeit(pipe)"
      ],
      "metadata": {
        "id": "CSXXwbNJSQbH"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "del pipe\n",
        "flush()"
      ],
      "metadata": {
        "id": "hTiHMZBGULvF"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## FP16 + SDPA"
      ],
      "metadata": {
        "id": "CdgoEbJ2fqT4"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "pipe = StableDiffusionXLPipeline.from_pretrained(\n",
        "    \"stabilityai/stable-diffusion-xl-base-1.0\", torch_dtype=torch.float16\n",
        ")\n",
        "pipe = pipe.to(\"cuda\")\n",
        "\n",
        "timeit(pipe)"
      ],
      "metadata": {
        "id": "2jExLo7IXert"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "del pipe\n",
        "flush()"
      ],
      "metadata": {
        "id": "L_lBhfSdYTOd"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "From here on, we refer to \"FP16 + SDPA\" as the default setting."
      ],
      "metadata": {
        "id": "r-uc8PqtU3Vj"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Default + `torch.compile()`"
      ],
      "metadata": {
        "id": "nBHYf6WTf5D-"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "pipe = StableDiffusionXLPipeline.from_pretrained(\n",
        "    \"stabilityai/stable-diffusion-xl-base-1.0\", torch_dtype=torch.float16\n",
        ")\n",
        "pipe = pipe.to(\"cuda\")\n",
        "pipe.unet = torch.compile(pipe.unet, mode=\"reduce-overhead\", fullgraph=True)\n",
        "\n",
        "timeit(pipe)"
      ],
      "metadata": {
        "id": "iQ9V76aeYqjU"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "del pipe\n",
        "flush()"
      ],
      "metadata": {
        "id": "W1MDkID1a_eU"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Default + Model CPU Offloading\n",
        "\n",
        "Here we focus on memory optimization rather than inference speed."
      ],
      "metadata": {
        "id": "fvIDIm-ff9AW"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "pipe = StableDiffusionXLPipeline.from_pretrained(\n",
        "    \"stabilityai/stable-diffusion-xl-base-1.0\", torch_dtype=torch.float16\n",
        ")\n",
        "pipe.enable_model_cpu_offload()\n",
        "\n",
        "timeit(pipe)"
      ],
      "metadata": {
        "id": "baM5rtmxcXOc"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "del pipe\n",
        "flush()"
      ],
      "metadata": {
        "id": "AI5V68_gdJ5T"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Default + Sequential CPU Offloading\n",
        "\n"
      ],
      "metadata": {
        "id": "HAqWhXOrgFGr"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "pipe = StableDiffusionXLPipeline.from_pretrained(\n",
        "    \"stabilityai/stable-diffusion-xl-base-1.0\", torch_dtype=torch.float16\n",
        ")\n",
        "pipe.enable_sequential_cpu_offload()\n",
        "\n",
        "timeit(pipe)"
      ],
      "metadata": {
        "id": "kGWRyvcidKT5"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "del pipe\n",
        "flush()"
      ],
      "metadata": {
        "id": "zSy6UWV6eYhS"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Default + VAE Slicing\n",
        "\n",
        "Specifically suited for optimizing memory for decoding latents into higher-res images without compromising too much on the inference speed."
      ],
      "metadata": {
        "id": "7t9pbR1PgT-z"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "pipe = StableDiffusionXLPipeline.from_pretrained(\n",
        "    \"stabilityai/stable-diffusion-xl-base-1.0\", torch_dtype=torch.float16\n",
        ")\n",
        "pipe = pipe.to(\"cuda\")\n",
        "pipe.enable_vae_slicing()\n",
        "\n",
        "timeit(pipe)"
      ],
      "metadata": {
        "id": "OcIeNb6CeZWi"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "del pipe\n",
        "flush()"
      ],
      "metadata": {
        "id": "4XAnHrc4jPZS"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Default + VAE Slicing + Sequential CPU Offloading\n"
      ],
      "metadata": {
        "id": "FP4Qo_Svi8i7"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "pipe = StableDiffusionXLPipeline.from_pretrained(\n",
        "    \"stabilityai/stable-diffusion-xl-base-1.0\", torch_dtype=torch.float16\n",
        ")\n",
        "pipe.enable_sequential_cpu_offload()\n",
        "pipe.enable_vae_slicing()\n",
        "\n",
        "timeit(pipe)"
      ],
      "metadata": {
        "id": "vK49SwlTi64T"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "del pipe\n",
        "flush()"
      ],
      "metadata": {
        "id": "ZVwDZCDcjQBG"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Default + Precomputing text embeddings"
      ],
      "metadata": {
        "id": "qEq-4plZiCOZ"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer\n",
        "\n",
        "\n",
        "pipe_id = \"stabilityai/stable-diffusion-xl-base-1.0\"\n",
        "torch_dtype = torch.float16\n",
        "\n",
        "# Load each text encoder (in `torch_dtype`, moved to the GPU) followed by its tokenizer.\n",
        "text_encoder = CLIPTextModel.from_pretrained(\n",
        "    pipe_id, subfolder=\"text_encoder\", torch_dtype=torch_dtype\n",
        ").to(\"cuda\")\n",
        "tokenizer = CLIPTokenizer.from_pretrained(pipe_id, subfolder=\"tokenizer\")\n",
        "\n",
        "text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(\n",
        "    pipe_id, subfolder=\"text_encoder_2\", torch_dtype=torch_dtype\n",
        ").to(\"cuda\")\n",
        "tokenizer_2 = CLIPTokenizer.from_pretrained(pipe_id, subfolder=\"tokenizer_2\")"
      ],
      "metadata": {
        "id": "-ptkitpYusV9"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def encode_prompt(tokenizers, text_encoders, prompt: str, negative_prompt: str = None):\n",
        "    device = text_encoders[0].device\n",
        "\n",
        "    if isinstance(prompt, str):\n",
        "        prompt = [prompt]\n",
        "    batch_size = len(prompt)\n",
        "\n",
        "    prompt_embeds_list = []\n",
        "    for tokenizer, text_encoder in zip(tokenizers, text_encoders):\n",
        "        text_inputs = tokenizer(\n",
        "            prompt,\n",
        "            padding=\"max_length\",\n",
        "            max_length=tokenizer.model_max_length,\n",
        "            truncation=True,\n",
        "            return_tensors=\"pt\",\n",
        "        )\n",
        "\n",
        "        text_input_ids = text_inputs.input_ids\n",
        "\n",
        "        prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)\n",
        "        pooled_prompt_embeds = prompt_embeds[0]\n",
        "        prompt_embeds = prompt_embeds.hidden_states[-2]\n",
        "        prompt_embeds_list.append(prompt_embeds)\n",
        "\n",
        "    prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)\n",
        "\n",
        "    if negative_prompt is None:\n",
        "        negative_prompt_embeds = torch.zeros_like(prompt_embeds)\n",
        "        negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)\n",
        "    else:\n",
        "        negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt\n",
        "\n",
        "        negative_prompt_embeds_list = []\n",
        "        for tokenizer, text_encoder in zip(tokenizers, text_encoders):\n",
        "            uncond_input = tokenizer(\n",
        "                negative_prompt,\n",
        "                padding=\"max_length\",\n",
        "                max_length=tokenizer.model_max_length,\n",
        "                truncation=True,\n",
        "                return_tensors=\"pt\",\n",
        "            )\n",
        "\n",
        "            negative_prompt_embeds = text_encoder(uncond_input.input_ids.to(device), output_hidden_states=True)\n",
        "            negative_pooled_prompt_embeds = negative_prompt_embeds[0]\n",
        "            negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]\n",
        "            negative_prompt_embeds_list.append(negative_prompt_embeds)\n",
        "\n",
        "        negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)\n",
        "\n",
        "    return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds"
      ],
      "metadata": {
        "id": "glaz8SYOBdG9"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "tokenizers = [tokenizer, tokenizer_2]\n",
        "text_encoders = [text_encoder, text_encoder_2]\n",
        "\n",
        "(\n",
        "    prompt_embeds,\n",
        "    negative_prompt_embeds,\n",
        "    pooled_prompt_embeds,\n",
        "    negative_pooled_prompt_embeds\n",
        ") = encode_prompt(tokenizers, text_encoders, prompt)"
      ],
      "metadata": {
        "id": "Kq2mquG4BnuV"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "del text_encoder, text_encoder_2, tokenizer, tokenizer_2\n",
        "flush()"
      ],
      "metadata": {
        "id": "rMKZs2CiVwUT"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "pipe = StableDiffusionXLPipeline.from_pretrained(\n",
        "    \"stabilityai/stable-diffusion-xl-base-1.0\",\n",
        "    text_encoder=None,\n",
        "    text_encoder_2=None,\n",
        "    tokenizer=None,\n",
        "    tokenizer_2=None,\n",
        "    torch_dtype=torch.float16,\n",
        ")\n",
        "pipe = pipe.to(\"cuda\")\n",
        "\n",
        "timeit(\n",
        "    pipe,\n",
        "    prompt_embeds,\n",
        "    negative_prompt_embeds,\n",
        "    pooled_prompt_embeds,\n",
        "    negative_pooled_prompt_embeds,\n",
        ")"
      ],
      "metadata": {
        "id": "Z5Nr3sNCsfG6"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "del pipe\n",
        "flush()"
      ],
      "metadata": {
        "id": "Rpj9XGYZVrR-"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Default + Tiny Autoencoder\n",
        "\n",
        "This is better suited for generating (almost) instant previews. The \"instant\" part is, of course, GPU-dependent. On an A10G, for example, it can be achieved."
      ],
      "metadata": {
        "id": "ZgGSeJ79hOF4"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from diffusers import AutoencoderTiny\n",
        "\n",
        "pipe = StableDiffusionXLPipeline.from_pretrained(\n",
        "    \"stabilityai/stable-diffusion-xl-base-1.0\", torch_dtype=torch.float16\n",
        ")\n",
        "pipe.vae = AutoencoderTiny.from_pretrained(\"madebyollin/taesdxl\", torch_dtype=torch.float16)\n",
        "pipe = pipe.to(\"cuda\")\n",
        "\n",
        "timeit(pipe)"
      ],
      "metadata": {
        "id": "RU4qIszCefNK"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}